/*
# Copyright (c) 2022 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/

#include <array>
#include <map>
#include <iostream>

#include <seeta/Struct.h>
#include <seeta/QualityAssessor.h>
#include <seeta/FaceDetector.h>
#include <seeta/FaceLandmarker.h>
#include <seeta/FaceDatabase.h>

#include <opencv2/opencv.hpp>

#include "Struct_cv.h"

using namespace std;

// Log helpers: print a printf-style message prefixed with a level tag, the
// calling function name and the source line. `fmt` must be a string literal
// because it is concatenated with the prefix. The `##` before __VA_ARGS__
// removes the preceding comma when no extra arguments are passed (a GNU
// extension supported by GCC and Clang).
#define HILOGI(fmt, ...) printf("[DEBUG][%s|%d]" fmt, __func__, __LINE__, ##__VA_ARGS__)
#define HILOGE(fmt, ...) printf("[ERROR][%s|%d]" fmt, __func__, __LINE__, ##__VA_ARGS__)

// Value handed to the detector through PROPERTY_MIN_FACE_SIZE before registration.
constexpr int MIN_FACE_SIZE = 80;

// Gallery of images registered into the face database at start-up.
// RegisterFace stores each registered id together with its filename, and
// RecognizeFace prints that filename as the matched person's name.
vector<string> GalleryImageFilename{
    "./xiaoming.jpg",
    "./xiaohong.jpg",
    "./xiaozhu.jpg",
};

/**
 * Runs the face detector on one image.
 *
 * @param FD    detector to run.
 * @param image image to search for faces.
 * @return a copy of every face the detector reported; empty if none.
 */
vector<SeetaFaceInfo> DetectFace(seeta::FaceDetector &FD, const SeetaImageData &image)
{
    const auto detected = FD.detect(image);

    // The detector hands back a (pointer, count) pair; copy it into owned storage.
    vector<SeetaFaceInfo> result;
    result.assign(detected.data, detected.data + detected.size);
    return result;
}

/**
 * Marks the facial landmarks of one detected face.
 *
 * @param PD    landmarker; PD.number() gives how many points it produces.
 * @param image image the face was detected in.
 * @param face  face rectangle (e.g. SeetaFaceInfo::pos from the detector).
 * @return PD.number() landmark points filled in by PD.mark().
 */
vector<SeetaPointF> DetectPoints(seeta::FaceLandmarker &PD, const SeetaImageData &image, const SeetaRect &face)
{
    vector<SeetaPointF> points(PD.number());
    PD.mark(image, face, points.data());

    // Return the local by name: wrapping it in std::move disables NRVO
    // (clang's -Wpessimizing-move) without any benefit.
    return points;
}

/**
 * Registers every image listed in GalleryImageFilename into the face database.
 *
 * For each gallery image, the first face reported by the detector is located,
 * its landmarks are marked and the face is registered in FDB. Images that
 * cannot be read, or in which no face is detected, are reported and skipped
 * instead of crashing the program.
 *
 * @param FD              face detector; its PROPERTY_MIN_FACE_SIZE is set to MIN_FACE_SIZE.
 * @param PD              landmarker used on the detected face.
 * @param FDB             database that receives the registered faces.
 * @param GalleryIndexMap output: database id -> gallery filename for every
 *                        registration that returned a non-negative id.
 */
static void RegisterFace(seeta::FaceDetector &FD, seeta::FaceLandmarker &PD, seeta::FaceDatabase &FDB,
                         map<int64_t, string> &GalleryIndexMap)
{
    HILOGI("start Registering: \r\n");

    FD.set(seeta::FaceDetector::PROPERTY_MIN_FACE_SIZE, MIN_FACE_SIZE);
    for (const string &filename : GalleryImageFilename) {
        HILOGI("Registering %s \r\n", filename.c_str());

        cv::Mat mat = cv::imread(filename);
        if (mat.empty()) {
            // cv::imread reports a missing/unreadable file by returning an empty Mat.
            HILOGE("failed to read %s, skipped\r\n", filename.c_str());
            continue;
        }
        seeta::cv::ImageData image = mat;

        vector<SeetaFaceInfo> faces = DetectFace(FD, image);
        if (faces.empty()) {
            // Without this guard faces[0] below would be undefined behaviour.
            HILOGE("no face detected in %s, skipped\r\n", filename.c_str());
            continue;
        }

        vector<SeetaPointF> points = DetectPoints(PD, image, faces[0].pos);
        const auto id = FDB.Register(image, points.data());
        // The id is 64-bit: print it via long long / %lld rather than %d.
        HILOGI("Registering id=%lld \r\n", static_cast<long long>(id));

        // Negative ids are not valid registrations; keep only the others.
        if (id >= 0) {
            GalleryIndexMap.emplace(static_cast<int64_t>(id), filename);
        }
    }
    HILOGI("Register end !! \r\n");
}

/**
 * Detects every face in an image and tries to identify it against FDB.
 *
 * Each detected face gets its landmarks marked and is scored by a
 * QualityAssessor. Faces scoring 0 are reported as "ignored". For the others,
 * the best database match is looked up and reported by its gallery filename
 * when the similarity exceeds the threshold. Otherwise the face is reported
 * as unmatched.
 *
 * @param FD              face detector.
 * @param PD              landmarker.
 * @param FDB             database previously filled by RegisterFace.
 * @param GalleryIndexMap database id -> gallery filename (read only here).
 * @param filname         path of the image to recognize.
 * @return 0 once the image has been processed, -1 if it could not be read.
 */
static int64_t RecognizeFace(seeta::FaceDetector &FD, seeta::FaceLandmarker &PD, seeta::FaceDatabase &FDB,
                             map<int64_t, string> &GalleryIndexMap, const char *filname)
{
    seeta::QualityAssessor QA;
    const float threshold = 0.7f;     // minimum similarity accepted as a match

    if (filname == nullptr) {
        HILOGE("no picture given to recognize!\r\n");
        return -1;
    }
    cv::Mat frame = cv::imread(filname);
    if (frame.empty()) {
        // cv::imread reports a missing/unreadable file by returning an empty Mat.
        HILOGE("failed to read picture %s\r\n", filname);
        return -1;
    }
    seeta::cv::ImageData image = frame;

    auto detected = FD.detect(image);
    vector<SeetaFaceInfo> faces(detected.data, detected.data + detected.size);
    if (faces.empty()) {
        HILOGE("the picture doesn't contain any faces!\r\n");
        return 0;
    }

    for (const SeetaFaceInfo &face : faces) {
        vector<SeetaPointF> points(PD.number());
        PD.mark(image, face.pos, points.data());    // locate the landmarks inside the face box
        auto score = QA.evaluate(image, face.pos, points.data());   // face quality score
        if (score == 0) {
            HILOGI("get result! name : %s\r\n", "ignored");
            continue;
        }

        int64_t index = -1;
        float similarity = 0;
        // Compare against the registered faces, keeping only the single best match.
        auto queried = FDB.QueryTop(image, points.data(), 1, &index, &similarity);
        if (queried < 1) {
            HILOGE("face query returned no result\r\n");
            continue;
        }

        // find() instead of operator[]: an unknown id must not insert an empty entry.
        auto it = GalleryIndexMap.find(index);
        if (similarity > threshold && it != GalleryIndexMap.end()) {
            HILOGI("get result! name : %s\r\n", it->second.c_str());
        } else {
            HILOGE("the face doesn't match any registered person!\r\n");
        }
    }

    return 0;
}

/**
 * Demo entry point. It registers the gallery images, then recognizes the
 * faces in the picture passed as the first command-line argument.
 *
 * Model files are loaded from ./model/ relative to the working directory.
 *
 * @return 0 on success, -1 on bad usage or if RecognizeFace reports a failure.
 */
int main(int argc, char **argv)
{
    if (argc < 2) {
        // argc can legitimately be 0, in which case argv[0] is a null pointer.
        const char *prog = (argc > 0 && argv[0] != nullptr) ? argv[0] : "<program>";
        HILOGE("usage %s picture name to recognise!!\r\n", prog);
        return -1;
    }
    const char *recognizeFile = argv[1];

    seeta::ModelSetting::Device device = seeta::ModelSetting::CPU;
    int id = 0;     // device index passed to every model

    seeta::ModelSetting FD_model("./model/fd_2_00.dat", device, id);
    seeta::ModelSetting PD_model("./model/pd_2_00_pts5.dat", device, id);
    seeta::ModelSetting FR_model("./model/fr_2_10.dat", device, id);

    seeta::FaceDetector FD(FD_model);    // face detection module
    seeta::FaceLandmarker PD(PD_model);   // facial landmark localization module
    seeta::FaceDatabase FDB(FR_model);    // face feature database module

    map<int64_t, string> GalleryIndexMap;

    RegisterFace(FD, PD, FDB, GalleryIndexMap);
    if (RecognizeFace(FD, PD, FDB, GalleryIndexMap, recognizeFile) < 0) {
        return -1;
    }

    return 0;
}
